{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "bOChJSNXtC9g"
   },
   "source": [
    "# Logistic Regression"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "OLIxEDq6VhvZ"
   },
   "source": [
    "<img src=\"https://raw.githubusercontent.com/LisonEvf/practicalAI-cn/master/images/logo.png\" width=150>\n",
    "\n",
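    "In the previous lesson, we saw how linear regression fits a line or hyperplane to predict continuous variables. In classification problems, however, we want the output to be class probabilities, and linear regression is not well suited to that."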
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "VoMq0eFRvugb"
   },
   "source": [
    "# Overview"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "qWro5T5qTJJL"
   },
   "source": [
    "<img src=\"https://raw.githubusercontent.com/LisonEvf/practicalAI-cn/master/images/logistic.jpg\" width=270>\n",
    "\n",
    "$ \\hat{y} = \\frac{1}{1 + e^{-XW}} $ \n",
    "\n",
    "*where*:\n",
    "* $\\hat{y}$ = predictions | $\\in \\mathbb{R}^{N \\times 1}$ ($N$ is the number of samples)\n",
    "* $X$ = inputs | $\\in \\mathbb{R}^{N \\times D}$ ($D$ is the number of features)\n",
    "* $W$ = weights | $\\in \\mathbb{R}^{D \\times 1}$ \n",
    "\n",
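    "This is binomial (binary) logistic regression. The main idea is to take the output of linear regression ($z=XW$) and pass it through the sigmoid function ($\\frac{1}{1+e^{-z}}$), which maps it into the range (0, 1)."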
   ]
  },
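  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal NumPy sketch of the binary case above, assuming small random toy data: compute the logits $z = XW$ and map them into (0, 1) with the sigmoid. The names `sigmoid`, `X` and `W` here are illustrative only and are not reused later in the lesson."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def sigmoid(z):\n",
    "    # Map any real-valued logit into the open interval (0, 1)\n",
    "    return 1.0 / (1.0 + np.exp(-z))\n",
    "\n",
    "# Hypothetical toy data: N=4 samples, D=3 features\n",
    "rng = np.random.RandomState(0)\n",
    "X = rng.randn(4, 3)\n",
    "W = rng.randn(3, 1)\n",
    "\n",
    "z = X.dot(W)        # logits, shape (N, 1)\n",
    "y_hat = sigmoid(z)  # predicted probabilities, shape (N, 1)\n",
    "print(y_hat.shape)  # (4, 1)\n",
    "print(sigmoid(0.0)) # 0.5: a logit of 0 gives equal odds"
   ]
  },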
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "YcFvkklZSZr9"
   },
   "source": [
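    "When we have more than two classes, we need to use multinomial logistic regression (the softmax classifier). The softmax classifier applies the linear equation ($z=XW$) and normalizes its output to produce the probability of class $y$ (the sum in the denominator runs over all $C$ classes).\n",
    "\n",
    "$ \\hat{y} = \\frac{e^{XW_y}}{\\sum e^{XW}} $ \n",
    "\n",
    "*where*:\n",
    "* $\\hat{y}$ = prediction (the predicted probability of class $y$) | $\\in \\mathbb{R}^{N \\times 1}$ ($N$ is the number of samples)\n",
    "* $X$ = inputs | $\\in \\mathbb{R}^{N \\times D}$ ($D$ is the number of features)\n",
    "* $W$ = weights | $\\in \\mathbb{R}^{D \\times C}$ ($C$ is the number of classes)\n"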
   ]
  },
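  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal NumPy sketch of the softmax classifier's forward pass, assuming small random toy data (the names `softmax`, `X` and `W` are illustrative). Subtracting each row's largest logit before exponentiating is a common numerical-stability trick; it leaves the probabilities unchanged because softmax is invariant to adding the same constant to all logits of a sample."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def softmax(z):\n",
    "    # Subtract each row's max logit for numerical stability (result is unchanged)\n",
    "    z = z - np.max(z, axis=1, keepdims=True)\n",
    "    exp_z = np.exp(z)\n",
    "    # Normalize each row so its C class probabilities sum to 1\n",
    "    return exp_z / np.sum(exp_z, axis=1, keepdims=True)\n",
    "\n",
    "# Hypothetical toy data: N=4 samples, D=3 features, C=3 classes\n",
    "rng = np.random.RandomState(0)\n",
    "X = rng.randn(4, 3)\n",
    "W = rng.randn(3, 3)\n",
    "\n",
    "probs = softmax(X.dot(W))          # class probabilities, shape (N, C)\n",
    "y_pred = np.argmax(probs, axis=1)  # most probable class for each sample\n",
    "print(probs.sum(axis=1))           # each row sums to 1\n",
    "print(y_pred)"
   ]
  },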
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "T4Y55tpzIjOa"
   },
   "source": [
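    "* **Objective:**  Predict the probability of class $y$ given the inputs $X$. The softmax classifier computes the class probabilities by normalizing the linear outputs. \n",
    "* **Advantages:**\n",
    "  * Can predict the class probabilities for a given input.\n",
    "* **Disadvantages:**\n",
    "  * Sensitive to outliers, because the objective is to minimize the cross-entropy loss. (Support vector machines ([SVMs](https://towardsdatascience.com/support-vector-machine-vs-logistic-regression-94cc2975433f)) are a good option for dealing with outliers.)\n",
    "* **Miscellaneous:** The softmax classifier is widely used as the last layer of neural network architectures, since it produces class probabilities."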
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Jq65LZJbSpzd"
   },
   "source": [
    "# Training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "-HBPn8zPTQfZ"
   },
   "source": [
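    "*Steps*:\n",
    "\n",
    "1. Randomly initialize the model's weights $W$.\n",
    "2. Feed the inputs $X$ into the model to get the logits ($z=XW$). Apply the softmax operation to the logits to get the class probabilities $\\hat{y}$, one per class. For example, with three classes the predicted class probabilities could be [0.3, 0.3, 0.4]. \n",
    "3. Use a loss function to compare the predictions $\\hat{y}$ (e.g. [0.3, 0.3, 0.4]) with the true values $y$ (e.g. a sample belonging to the third class is written as the one-hot vector [0, 0, 1]) and compute the loss $J$. A common loss function for logistic regression is the cross-entropy loss. \n",
    "  * $J(\\theta) = - \\sum_i y_i \\ln (\\hat{y_i}) =  - \\sum_i y_i \\ln (\\frac{e^{X_iW_y}}{\\sum e^{X_iW}}) $\n",
    "  * $y$ = [0, 0, 1]\n",
    "  * $\\hat{y}$ = [0.3, 0.3, 0.4]\n",
    "  * $J(\\theta) = - \\sum_i y_i \\ln (\\hat{y_i}) =  - \\sum_i y_i \\ln (\\frac{e^{X_iW_y}}{\\sum e^{X_iW}}) = - [0 \\cdot \\ln(0.3) + 0 \\cdot \\ln(0.3) + 1 \\cdot \\ln(0.4)] = -\\ln(0.4) $\n",
    "  * Since only the correct class has $y_i = 1$, the cross-entropy simplifies to $J(\\theta) = - \\ln(\\hat{y_i})$ (the negative log-likelihood), where $\\hat{y_i}$ is the predicted probability of the correct class.\n",
    "  * $J(\\theta) = - \\ln(\\hat{y_i}) = - \\ln (\\frac{e^{X_iW_y}}{\\sum e^{X_iW}}) $\n",
    "4. Compute the gradient of the loss $J(\\theta)$ with respect to the model weights. Let's assume the classes are mutually exclusive (each input belongs to exactly one class). Below, $P_j = \\frac{e^{XW_j}}{\\sum e^{XW}}$ denotes the predicted probability of class $j$.\n",
    "  * $\\frac{\\partial{J}}{\\partial{W_j}} = \\frac{\\partial{J}}{\\partial{\\hat{y}}}\\frac{\\partial{\\hat{y}}}{\\partial{W_j}} = - \\frac{1}{\\hat{y}}\\frac{\\partial{\\hat{y}}}{\\partial{W_j}} = - \\frac{1}{\\frac{e^{XW_y}}{\\sum e^{XW}}}\\frac{\\sum e^{XW} \\cdot 0 - e^{XW_y}e^{XW_j}X}{(\\sum e^{XW})^2} = \\frac{Xe^{XW_j}}{\\sum e^{XW}} = XP_j$\n",
    "  * $\\frac{\\partial{J}}{\\partial{W_y}} = \\frac{\\partial{J}}{\\partial{\\hat{y}}}\\frac{\\partial{\\hat{y}}}{\\partial{W_y}} = - \\frac{1}{\\hat{y}}\\frac{\\partial{\\hat{y}}}{\\partial{W_y}} = - \\frac{1}{\\frac{e^{XW_y}}{\\sum e^{XW}}}\\frac{\\sum e^{XW}e^{XW_y}X - e^{XW_y}e^{XW_y}X}{(\\sum e^{XW})^2} = -\\frac{1}{P_y}(XP_y - XP_y^2) = X(P_y-1)$\n",
    "5. Use gradient descent to backpropagate these gradients and update the model weights. The update lowers the probabilities of the incorrect classes ($j$) and raises the probability of the correct class ($y$).\n",
    "  * $W_i = W_i - \\alpha\\frac{\\partial{J}}{\\partial{W_i}}$\n",
    "6. Repeat steps 2 - 5 until the model performs well (i.e. until the loss converges)."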
   ]
  },
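  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The steps above can be sketched from scratch in NumPy. This is a minimal illustration under stated assumptions, not the lesson's implementation: the three-class toy data, the learning rate `alpha = 0.1`, the 100 epochs and the absence of a bias term are all choices made for this example. The loss is averaged over samples, and the gradient follows step 4: $P_j$ for every incorrect class and $P_y - 1$ for the correct class. The first lines also reproduce the worked cross-entropy example from step 3."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def softmax(z):\n",
    "    z = z - np.max(z, axis=1, keepdims=True)  # numerical stability\n",
    "    exp_z = np.exp(z)\n",
    "    return exp_z / np.sum(exp_z, axis=1, keepdims=True)\n",
    "\n",
    "# Step 3's worked example: y = [0, 0, 1] and y_hat = [0.3, 0.3, 0.4]\n",
    "example_loss = -np.sum(np.array([0, 0, 1]) * np.log(np.array([0.3, 0.3, 0.4])))\n",
    "print(example_loss)  # equals -ln(0.4), about 0.916\n",
    "\n",
    "# Hypothetical toy data: 20 two-dimensional points around each of 3 centers\n",
    "rng = np.random.RandomState(1234)\n",
    "centers = np.array([[0.0, 3.0], [3.0, 0.0], [-3.0, -3.0]])\n",
    "y = np.repeat(np.arange(3), 20)        # true class indices, shape (N,)\n",
    "X = centers[y] + rng.randn(len(y), 2)  # inputs, shape (N, D)\n",
    "N, D, C = X.shape[0], X.shape[1], 3\n",
    "\n",
    "W = 0.01 * rng.randn(D, C)  # step 1: random initialization\n",
    "alpha = 0.1                 # learning rate (assumed value)\n",
    "losses = []\n",
    "for epoch in range(100):                                 # step 6: repeat\n",
    "    P = softmax(X.dot(W))                                # step 2: logits -> probabilities\n",
    "    losses.append(-np.mean(np.log(P[np.arange(N), y])))  # step 3: cross-entropy (NLL)\n",
    "    dscores = P.copy()\n",
    "    dscores[np.arange(N), y] -= 1                        # step 4: P_j, or P_y - 1 for the true class\n",
    "    dW = X.T.dot(dscores) / N\n",
    "    W -= alpha * dW                                      # step 5: gradient descent update\n",
    "\n",
    "accuracy = np.mean(np.argmax(X.dot(W), axis=1) == y)\n",
    "print('loss: {0:.3f} -> {1:.3f}, accuracy: {2:.2f}'.format(losses[0], losses[-1], accuracy))"
   ]
  },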
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "r_hKrjzdtTgM"
   },
   "source": [
    "# Data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "PyccHrQztVEu"
   },
   "source": [
    "Let's load the Titanic dataset that we used in lesson 3."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "H385V4VUtWOv"
   },
   "outputs": [],
   "source": [
    "from argparse import Namespace\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import urllib.request"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "pL67TlZO6Zg4"
   },
   "outputs": [],
   "source": [
    "# Arguments\n",
    "args = Namespace(\n",
    "    seed=1234,\n",
    "    data_file=\"titanic.csv\",\n",
    "    train_size=0.75,\n",
    "    test_size=0.25,\n",
    "    num_epochs=100,\n",
    ")\n",
    "\n",
    "# Set the random seed so that results are reproducible.\n",
    "np.random.seed(args.seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "7sp_tSyItf1_"
   },
   "outputs": [],
   "source": [
    "# Download the data from GitHub to the notebook's local drive\n",
    "url = \"https://raw.githubusercontent.com/LisonEvf/practicalAI-cn/master/data/titanic.csv\"\n",
    "response = urllib.request.urlopen(url)\n",
    "html = response.read()\n",
    "with open(args.data_file, 'wb') as f:\n",
    "    f.write(html)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 272
    },
    "colab_type": "code",
    "id": "7alqmyzXtgE8",
    "outputId": "353702e3-76f7-479d-df7a-5effcc8a7461"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>pclass</th>\n",
       "      <th>name</th>\n",
       "      <th>sex</th>\n",
       "      <th>age</th>\n",
       "      <th>sibsp</th>\n",
       "      <th>parch</th>\n",
       "      <th>ticket</th>\n",
       "      <th>fare</th>\n",
       "      <th>cabin</th>\n",
       "      <th>embarked</th>\n",
       "      <th>survived</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>Allen, Miss. Elisabeth Walton</td>\n",
       "      <td>female</td>\n",
       "      <td>29.0000</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>24160</td>\n",
       "      <td>211.3375</td>\n",
       "      <td>B5</td>\n",
       "      <td>S</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>Allison, Master. Hudson Trevor</td>\n",
       "      <td>male</td>\n",
       "      <td>0.9167</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>113781</td>\n",
       "      <td>151.5500</td>\n",
       "      <td>C22 C26</td>\n",
       "      <td>S</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1</td>\n",
       "      <td>Allison, Miss. Helen Loraine</td>\n",
       "      <td>female</td>\n",
       "      <td>2.0000</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>113781</td>\n",
       "      <td>151.5500</td>\n",
       "      <td>C22 C26</td>\n",
       "      <td>S</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1</td>\n",
       "      <td>Allison, Mr. Hudson Joshua Creighton</td>\n",
       "      <td>male</td>\n",
       "      <td>30.0000</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>113781</td>\n",
       "      <td>151.5500</td>\n",
       "      <td>C22 C26</td>\n",
       "      <td>S</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1</td>\n",
       "      <td>Allison, Mrs. Hudson J C (Bessie Waldo Daniels)</td>\n",
       "      <td>female</td>\n",
       "      <td>25.0000</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>113781</td>\n",
       "      <td>151.5500</td>\n",
       "      <td>C22 C26</td>\n",
       "      <td>S</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   pclass                                             name     sex      age  \\\n",
       "0       1                    Allen, Miss. Elisabeth Walton  female  29.0000   \n",
       "1       1                   Allison, Master. Hudson Trevor    male   0.9167   \n",
       "2       1                     Allison, Miss. Helen Loraine  female   2.0000   \n",
       "3       1             Allison, Mr. Hudson Joshua Creighton    male  30.0000   \n",
       "4       1  Allison, Mrs. Hudson J C (Bessie Waldo Daniels)  female  25.0000   \n",
       "\n",
       "   sibsp  parch  ticket      fare    cabin embarked  survived  \n",
       "0      0      0   24160  211.3375       B5        S         1  \n",
       "1      1      2  113781  151.5500  C22 C26        S         1  \n",
       "2      1      2  113781  151.5500  C22 C26        S         0  \n",
       "3      1      2  113781  151.5500  C22 C26        S         0  \n",
       "4      1      2  113781  151.5500  C22 C26        S         0  "
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Read the CSV file into a DataFrame\n",
    "df = pd.read_csv(args.data_file, header=0)\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "k-5Y4zLIoE6s"
   },
   "source": [
    "# Scikit-learn implementation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "ILkbyBHQoIwE"
   },
   "source": [
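    "**Note**: Scikit-learn's `LogisticRegression` class is fitted with dedicated solvers (for example `liblinear`, which uses coordinate descent, or `lbfgs`) rather than stochastic gradient descent. Instead, we will use Scikit-learn's `SGDClassifier` class, which performs stochastic gradient descent (SGD). We use this optimization method because we will also be using it in the next few lessons."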
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "W1MJODStIu8V"
   },
   "outputs": [],
   "source": [
    "# Load packages\n",
    "from sklearn.linear_model import SGDClassifier\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.model_selection import train_test_split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "kItBIOOCTi6p"
   },
   "outputs": [],
   "source": [
    "# Preprocessing\n",
    "def preprocess(df):\n",
    "\n",
    "    # Drop rows with NaN values\n",
    "    df = df.dropna()\n",
    "\n",
    "    # Drop text-based features (we'll learn how to use them in later lessons)\n",
    "    features_to_drop = [\"name\", \"cabin\", \"ticket\"]\n",
    "    df = df.drop(features_to_drop, axis=1)\n",
    "\n",
    "    # pclass, sex, and embarked are categorical features\n",
    "    categorical_features = [\"pclass\",\"embarked\",\"sex\"]\n",
    "    df = pd.get_dummies(df, columns=categorical_features)\n",
    "\n",
    "    return df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 224
    },
    "colab_type": "code",
    "id": "QwQHDh4xuYTB",
    "outputId": "153ea757-b817-406d-dbde-d1fba88f194b"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>sibsp</th>\n",
       "      <th>parch</th>\n",
       "      <th>fare</th>\n",
       "      <th>survived</th>\n",
       "      <th>pclass_1</th>\n",
       "      <th>pclass_2</th>\n",
       "      <th>pclass_3</th>\n",
       "      <th>embarked_C</th>\n",
       "      <th>embarked_Q</th>\n",
       "      <th>embarked_S</th>\n",
       "      <th>sex_female</th>\n",
       "      <th>sex_male</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>29.0000</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>211.3375</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.9167</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>151.5500</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2.0000</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>151.5500</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>30.0000</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>151.5500</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>25.0000</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>151.5500</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       age  sibsp  parch      fare  survived  pclass_1  pclass_2  pclass_3  \\\n",
       "0  29.0000      0      0  211.3375         1         1         0         0   \n",
       "1   0.9167      1      2  151.5500         1         1         0         0   \n",
       "2   2.0000      1      2  151.5500         0         1         0         0   \n",
       "3  30.0000      1      2  151.5500         0         1         0         0   \n",
       "4  25.0000      1      2  151.5500         0         1         0         0   \n",
       "\n",
       "   embarked_C  embarked_Q  embarked_S  sex_female  sex_male  \n",
       "0           0           0           1           1         0  \n",
       "1           0           0           1           0         1  \n",
       "2           0           0           1           1         0  \n",
       "3           0           0           1           0         1  \n",
       "4           0           0           1           1         0  "
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Preprocess the data\n",
    "df = preprocess(df)\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "id": "wsGRZNNiUTqj",
    "outputId": "c9364be7-3cae-487f-9d96-3210b3129199"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train size: 199, test size: 71\n"
     ]
    }
   ],
   "source": [
    "# Split the data into training and test sets\n",
    "mask = np.random.rand(len(df)) < args.train_size\n",
    "train_df = df[mask]\n",
    "test_df = df[~mask]\n",
    "print (\"Train size: {0}, test size: {1}\".format(len(train_df), len(test_df)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "oZKxFmATU95M"
   },
   "source": [
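    "**Note**: If you have preprocessing steps such as standardization, apply them after splitting the data into training and test sets, and fit them on the training set only. We must not learn any information from the test set; otherwise the test results would be overly optimistic."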
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "cLzL_LJd4vQ-"
   },
   "outputs": [],
   "source": [
    "# Separate X and y\n",
    "X_train = train_df.drop([\"survived\"], axis=1)\n",
    "y_train = train_df[\"survived\"]\n",
    "X_test = test_df.drop([\"survived\"], axis=1)\n",
    "y_test = test_df[\"survived\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 85
    },
    "colab_type": "code",
    "id": "AdTYbV472UNJ",
    "outputId": "214a8114-3fd3-407f-cd6e-5f5d07294f50"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mean: [-1.78528326e-17  7.14113302e-17 -5.80217058e-17 -5.35584977e-17\n",
      "  3.57056651e-17 -8.92641628e-17  3.57056651e-17 -3.79372692e-17\n",
      "  0.00000000e+00  3.79372692e-17  1.04885391e-16 -6.69481221e-17]\n",
      "std: [1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1.]\n"
     ]
    }
   ],
   "source": [
    "# Fit the standardizer on the training data only (mean=0, std=1)\n",
    "X_scaler = StandardScaler().fit(X_train)\n",
    "\n",
    "# Standardize the training and test data (do not standardize the labels y)\n",
    "standardized_X_train = X_scaler.transform(X_train)\n",
    "standardized_X_test = X_scaler.transform(X_test)\n",
    "\n",
    "# Check\n",
    "print (\"mean:\", np.mean(standardized_X_train, axis=0)) # mean should be ~0\n",
    "print (\"std:\", np.std(standardized_X_train, axis=0))   # std should be 1 (a constant column stays at 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "7-vm9AZm1_f9"
   },
   "outputs": [],
   "source": [
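    "# Initialize the model\n",
    "# (newer scikit-learn versions renamed loss=\"log\" to loss=\"log_loss\" and replaced penalty=\"none\" with penalty=None)\n",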
    "log_reg = SGDClassifier(loss=\"log\", penalty=\"none\", max_iter=args.num_epochs, \n",
    "                        random_state=args.seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 102
    },
    "colab_type": "code",
    "id": "0e8U9NNluYVp",
    "outputId": "c5f22ade-bb8c-479b-d300-98758a82d396"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SGDClassifier(alpha=0.0001, average=False, class_weight=None, epsilon=0.1,\n",
       "       eta0=0.0, fit_intercept=True, l1_ratio=0.15,\n",
       "       learning_rate='optimal', loss='log', max_iter=100, n_iter=None,\n",
       "       n_jobs=1, penalty='none', power_t=0.5, random_state=1234,\n",
       "       shuffle=True, tol=None, verbose=0, warm_start=False)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Train\n",
    "log_reg.fit(X=standardized_X_train, y=y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 102
    },
    "colab_type": "code",
    "id": "hA7Oz97NAe8A",
    "outputId": "ab8a878a-6012-4727-8cd1-40bc5c69245b"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0.60319594 0.39680406]\n",
      " [0.00374908 0.99625092]\n",
      " [0.81886302 0.18113698]\n",
      " [0.01082253 0.98917747]\n",
      " [0.93508814 0.06491186]]\n"
     ]
    }
   ],
   "source": [
    "# Probabilities\n",
    "pred_test = log_reg.predict_proba(standardized_X_test)\n",
    "print (pred_test[:5])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 51
    },
    "colab_type": "code",
    "id": "-jZtTd7F6_ps",
    "outputId": "d2306e4c-88a4-4ac4-9ad5-879fa461617f"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0 1 0 1 0 1 0 0 1 1 0 0 0 0 1 0 1 0 0 0 1 1 1 1 0 0 0 0 1 0 0 0 1 0 0 1 0\n",
      " 1 0 0 1 1 1 0 1 1 0 0 0 0 1 0 0 1 0 1 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1]\n"
     ]
    }
   ],
   "source": [
    "# Predictions (the labels y were never standardized, so no inverse transform is needed)\n",
    "pred_train = log_reg.predict(standardized_X_train) \n",
    "pred_test = log_reg.predict(standardized_X_test)\n",
    "print (pred_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "dM7iYW8ANYjy"
   },
   "source": [
    "# Evaluation metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "uFXbczqu8Rno"
   },
   "outputs": [],
   "source": [
    "from sklearn.metrics import accuracy_score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "id": "sEjansj78Rqe",
    "outputId": "f5bfbe87-12c9-4aa5-fc61-e615ad4e63d4"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train acc: 0.77, test acc: 0.82\n"
     ]
    }
   ],
   "source": [
    "# Accuracy\n",
    "train_acc = accuracy_score(y_train, pred_train)\n",
    "test_acc = accuracy_score(y_test, pred_test)\n",
    "print (\"train acc: {0:.2f}, test acc: {1:.2f}\".format(train_acc, test_acc))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "WijzY-vDNbE9"
   },
   "source": [
    "So far we have used accuracy as our metric for evaluating how well the model performs, but there are many other evaluation metrics we can use as well.\n",
    "\n",
    "<img src=\"https://raw.githubusercontent.com/LisonEvf/practicalAI-cn/master/images/metrics.jpg\" width=400>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "80MwyE0yOr-k"
   },
   "source": [
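    "The choice of evaluation metric really depends on the application.\n",
    "positive - true, 1, tumor, issue, etc.; negative - false, 0, not tumor, not issue, etc.\n",
    "\n",
    "$\\text{accuracy} = \\frac{TP+TN}{TP+TN+FP+FN}$ \n",
    "\n",
    "$\\text{recall} = \\frac{TP}{TP+FN}$ → (of all the actual positives, how many did I classify as positive)\n",
    "\n",
    "$\\text{precision} = \\frac{TP}{TP+FP}$ → (of all the samples I predicted as positive, how many are actually positive)\n",
    "\n",
    "$F_1 = 2 * \\frac{\\text{precision }  *  \\text{ recall}}{\\text{precision } + \\text{ recall}}$\n",
    "\n",
    "where: \n",
    "* TP: number of positive samples predicted as positive\n",
    "* TN: number of negative samples predicted as negative\n",
    "* FP: number of negative samples predicted as positive\n",
    "* FN: number of positive samples predicted as negative"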
   ]
  },
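  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The formulas above can be checked with a small pure-Python helper. The function name `binary_metrics` and the example labels are hypothetical, and the positive class is assumed to be 1. For class 1, the values should agree with the corresponding row of scikit-learn's `classification_report`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def binary_metrics(y_true, y_pred):\n",
    "    # Count the four confusion-matrix cells, treating 1 as the positive class\n",
    "    pairs = list(zip(y_true, y_pred))\n",
    "    tp = sum(1 for t, p in pairs if t == 1 and p == 1)\n",
    "    tn = sum(1 for t, p in pairs if t == 0 and p == 0)\n",
    "    fp = sum(1 for t, p in pairs if t == 0 and p == 1)\n",
    "    fn = sum(1 for t, p in pairs if t == 1 and p == 0)\n",
    "    accuracy = (tp + tn) / (tp + tn + fp + fn)\n",
    "    # Guard against division by zero when a denominator is empty\n",
    "    recall = tp / (tp + fn) if tp + fn else 0.0\n",
    "    precision = tp / (tp + fp) if tp + fp else 0.0\n",
    "    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0\n",
    "    return accuracy, recall, precision, f1\n",
    "\n",
    "# Hypothetical labels: TP=3, TN=2, FP=2, FN=1\n",
    "y_true = [1, 1, 1, 0, 0, 0, 0, 1]\n",
    "y_pred = [1, 0, 1, 0, 1, 1, 0, 1]\n",
    "accuracy, recall, precision, f1 = binary_metrics(y_true, y_pred)\n",
    "print(accuracy, recall, precision, f1)  # accuracy 0.625, recall 0.75, precision 0.6, F1 about 0.667"
   ]
  },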
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "opmu3hJm9LXA"
   },
   "outputs": [],
   "source": [
    "import itertools\n",
    "from sklearn.metrics import classification_report, confusion_matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "wAzOL8h29m82"
   },
   "outputs": [],
   "source": [
    "# Plot the confusion matrix\n",
    "def plot_confusion_matrix(cm, classes):\n",
    "    cmap=plt.cm.Blues\n",
    "    plt.imshow(cm, interpolation='nearest', cmap=cmap)\n",
    "    plt.title(\"Confusion Matrix\")\n",
    "    plt.colorbar()\n",
    "    tick_marks = np.arange(len(classes))\n",
    "    plt.xticks(tick_marks, classes, rotation=45)\n",
    "    plt.yticks(tick_marks, classes)\n",
    "    plt.grid(False)\n",
    "\n",
    "    fmt = 'd'\n",
    "    thresh = cm.max() / 2.\n",
    "    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):\n",
    "        plt.text(j, i, format(cm[i, j], fmt),\n",
    "                 horizontalalignment=\"center\",\n",
    "                 color=\"white\" if cm[i, j] > thresh else \"black\")\n",
    "\n",
    "    plt.ylabel('True label')\n",
    "    plt.xlabel('Predicted label')\n",
    "    plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 520
    },
    "colab_type": "code",
    "id": "KqUVzahQ-5ic",
    "outputId": "bff8819e-3d5b-45b9-c221-179c873140b1"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "             precision    recall  f1-score   support\n",
      "\n",
      "          0       0.74      0.91      0.82        32\n",
      "          1       0.91      0.74      0.82        39\n",
      "\n",
      "avg / total       0.83      0.82      0.82        71\n",
      "\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAT8AAAEYCAYAAAAqD/ElAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAIABJREFUeJzt3XmcXvP5//HXexIiJBISiZ3WErWGWGOL2kKpfY2KpZHavlpUU1qipa1WUUv5Sak9dpXaIqVBECSaIBVRS2pJZFFiCbJcvz/OGW5jZu57Zu65t/N+9nEec9/nnPtzrntirn628zmKCMzMsqau3AGYmZWDk5+ZZZKTn5llkpOfmWWSk5+ZZZKTn5llkpOftYikzpL+LulDSXe0oZxBkh4uZmzlIOlBSYPLHYe1nJNfjZJ0uKQJkj6WNCP9I92uCEUfCPQGekTEQa0tJCJujojdihDP10gaICkk3d1g/ybp/rEFljNc0k35zouIPSLi+laGa2Xk5FeDJJ0KXAL8hiRRrQ78GdinCMWvAUyLiIVFKKu9zAb6S+qRs28wMK1YF1DCfz/VLCK81dAGdAM+Bg5q5pxOJMnx3XS7BOiUHhsAvA2cBswCZgBHp8fOBb4AFqTXOBYYDtyUU/aaQAAd0/dHAa8DHwFvAINy9o/L+Vx/4Dngw/Rn/5xjY4FfA0+m5TwM9Gziu9XHfxVwYrqvQ7rvbGBszrl/At4C5gETge3T/QMbfM/JOXGcn8YxH1g73ffD9PiVwJ055V8APAKo3P9dePvm5v/nqj3bAEsB9zRzzlnA1kBfYBNgS+AXOcdXJEmiq5AkuCskLRcR55DUJm+LiC4RcU1zgUhaBrgU2CMiupIkuEmNnLc8cH96bg/gIuD+BjW3w4GjgV7AksDpzV0buAE4Mn29OzCFJNHneo7kd7A8cAtwh6SlIuKhBt9zk5zP/AA4DugKTG9Q3mnAxpKOkrQ9ye9ucKSZ0CqLk1/t6QHMieabpYOAX0XErIiYTVKj+0HO8QXp8QUR8QBJ7adPK+NZDGwoqXNEzIiIKY2c8z3g1Yi4MSIWRsRIYCqwd845f42IaRExH7idJGk1KSKeApaX1IckCd7QyDk3RcTc9Jp/JKkR5/ue10XElPQzCxqU9ylwBEnyvgk4OSLezlOelYmTX+2ZC/SU1LGZc1bm67WW6em+L8tokDw/Bbq0NJCI+AQ4BPgRMEPS/ZLWKyCe+phWyXk/sxXx3AicBOxEIzVhSadJejkduf6ApLbbM0+ZbzV3MCKeJWnmiyRJW4Vy8qs9TwOfAfs2c867JAMX9Vbnm03CQn0CLJ3zfsXcgxExOiJ2BVYiqc2NKCCe+pjeaWVM9W4ETgAeSGtlX0qbpT8DDgaWi4juJP2Nqg+9iTKbbcJKOpGkBvkucEbrQ7f25uRXYyLiQ5KO/Ssk7StpaUlLSNpD0u/T00YCv5C0gqSe6fl5p3U0YRKwg6TVJXUDfl5/QFJvSd9P+/4+J2k+L2qkjAeAddPpOR0lHQKsD9zXypgAiIg3gB1J+jgb6gosJBkZ7ijpbGDZnOPvAWu2ZERX0rrAeSRN3x8AZ0hqtnlu5ePkV4Mi4iLgVJJBjNkkTbWTgL+lp5wHTABeAF4Enk/3teZaY4Db0rIm8vWEVUcyCPAu8D5JIjqhkTLmAnul584lqTHtFRFzWhNTg7LHRURjtdrRwIMk01+mk9SWc5u09RO450p6Pt910m6Gm4ALImJyRLwKnAncKKlTW76DtQ95IMrMssg1PzPLJCc/M8skJz8zyyQnPzPLpOYmwmaeOnYOLdm13GFk1qbfWb3cIWTa9OlvMmfOHOU/szAdll0jYuH8vOfF/NmjI2Jgsa7bFCe/ZmjJrnTqc3C5w8isJ5+5vNwhZNq2W21e1PJi4fyC/p4+m3RFvrtsisLJz8xKQ4K6DuWO4ktOfmZWOhW0BKKTn5mVjorW
hdhmTn5mViJu9ppZFgk3e80si+Rmr5lllJu9ZpY9crPXzDJIuNlrZlkkqKuclFM5kZhZ7atzzc/MssZTXcwsmzzJ2cyyygMeZpZJbvaaWeZ4SSszyyw3e80se3yHh5llkXCz18yyyDU/M8sq9/mZWSa52WtmmSM3e80sq9zsNbOsEVBXVzk1v8qJxMxqmwrcmitCWk3SPyW9LGmKpFPS/cMlvSNpUrrtmS8c1/zMrESE2t7sXQicFhHPS+oKTJQ0Jj12cURcWGhBTn5mVjJtbfZGxAxgRvr6I0kvA6u0KpY2RWJm1gKS8m5AT0kTcrbjmihrTWBT4Jl010mSXpB0raTl8sXi5GdmpVF4n9+ciNg8Z7v6G0VJXYC7gB9HxDzgSmAtoC9JzfCP+cJxs9fMSkKoKKO9kpYgSXw3R8TdABHxXs7xEcB9+cpxzc/MSqbAZm9znxdwDfByRFyUs3+lnNP2A17KF4trfmZWMkUY7d0W+AHwoqRJ6b4zgcMk9QUCeBMYmq8gJz8zKw2B2vjoyogYR+OzAR9oaVlOfmZWEirOPL+icfIzs5Jx8jOz7ClCs7eYnPzMrGRc8zOzTHLys6JatXd3/vLrI+ndY1kWR3DtXU9yxcixbLTuKlx21qEs07kT09+dy9FnXc9Hn3xW7nBr3meffcYuO+3AF59/zsJFC9lv/wP55TnnljusshNys9eKa+GixQy76G4mTX2bLkt34qlbfsYjz0zlyrMPZ9jF9zBu4n84cp+t+cngnfnVn+8vd7g1r1OnTjw05lG6dOnCggUL+O6O27Hb7nuw1dZblzu08lJl1fx8h0cNmDlnHpOmvg3Ax59+ztQ3ZrLyCt1ZZ41ejJv4HwAeHT+VfXfuW84wM0MSXbp0AWDBggUsXLCgov7oy6mtd3gUk5NfjVl9peXp22dVnnvpTf792gz2GrARAPvvuhmr9s670IUVyaJFi9iqX19WX7kX391lV7bcaqtyh1QRVKe8W6lUdfJLV289XdKvJO3Sgs+tKSnvvX/VZpnOSzLywh/y0wvv4qNPPmPo8JsZevAOPHnzGXRZuhNfLFhU7hAzo0OHDjwzcRL/efNtJjz3LFNeqrn/3Fqlkmp+NdHnFxFnlzuGcuvYsY6RFw7htgcncO+jkwGY9uZ77H3CFQCsvXov9th+g3KGmEndu3dnhx0H8PDDD7HBhhuWO5yyKnVyy6fqan6SzpL0iqR/AH3SfddJOjB93U/SY5ImShpdv9pDun+ypKeBE8v3DdrHVecM4pU3ZnLpTY9+uW+F5ZJ+J0kMG7I7I+4cV67wMmX27Nl88MEHAMyfP59HH/kHffqsV+aoKkNdXV3erVSqquYnqR9wKMnqrR2B54GJOceXAC4D9omI2ZIOAc4HjgH+CpwcEY9J+kMz1zgOSFaOXaJLO32T4urf99sM2msrXpz2DuNvHQbAOZePYu3VejH0kB0AuPfRSdxw7/hyhpkZM2fMYMgxg1m0aBGLYzEHHHgwe35vr3KHVRkqp+JXXckP2B64JyI+BZA0qsHxPsCGwJi0et0BmCGpG9A9Ih5Lz7sR2KOxC6Srxl4NULd0ryj6N2gHT016nc6bnvSN/aP5N1eMHFv6gDJuo403ZvyEf5U7jIpUSc3eakt+kKzX1RQBUyJim6/tlLrn+ZyZtTMJ6ipoknO19fk9DuwnqXP62Lq9Gxx/BVhB0jaQNIMlbRARHwAfStouPW9Q6UI2s0T+kV6P9jYhfVbnbcAkYDrwRIPjX6QDH5emTd2OwCXAFOBo4FpJnwKjSxu5mUFS+6sUVZX8ACLifJJBjKaOTwJ2aGT/RGCTnF3Dix6cmTWtwpq9VZf8zKw6CSc/M8soN3vNLHvc7DWzLBKe52dmmVRZ9/Y6+ZlZybjZa2bZIw94mFkGuc/PzDLLzV4zy6QKqvhV3cIGZlat1PZl7CWtJumfkl6WNEXSKen+5SWNkfRq+jPvA2uc/MysJISoq8u/5bEQOC0ivgNs
DZwoaX1gGPBIRKwDPJK+b5aTn5mVjJR/a05EzIiI59PXHwEvA6sA+wDXp6ddD+ybLxb3+ZlZyRQ42ttT0oSc91enK6w3LGtNkkdaPAP0jogZkCRISb3yXcTJz8xKogUrOc+JiM2bL0tdgLuAH0fEvNZMoXGz18xKphgrOacPKrsLuDki7k53v5fzpMaVgFn5ynHyM7OSaWufn5LseA3wckRclHNoFDA4fT0YuDdfLG72mllpFGdJq22BHwAvSpqU7jsT+B1wu6Rjgf8CB+UryMnPzEpCRVjVJSLG0fTTf3duSVlNJj9Jy+YJYl5LLmRmVkl3eDRX85tC8qzb3HDr3wewejvGZWY1qEM13NsbEauVMhAzq21SZa3qUtBor6RDJZ2Zvl5VUr/2DcvMalGd8m8liyXfCZIuB3YiGWEB+BS4qj2DMrPaVIR7e4umkNHe/hGxmaR/AUTE+5KWbOe4zKzGiGTEt1IUkvwWSKojGeRAUg9gcbtGZWY1qYLGOwpKfleQ3EqygqRzgYOBc9s1KjOrPSptszafvMkvIm6QNBHYJd11UES81L5hmVmtEVBXQaO9hd7h0QFYQNL09f3AZtYqFZT7ChrtPQsYCawMrArcIunn7R2YmdWW+iWtqmm09wigX0R8CiDpfGAi8Nv2DMzMak+1NXunNzivI/B6+4RjZrWsclJf8wsbXEzSx/cpMEXS6PT9bsC40oRnZrVCVMm9vUD9iO4U4P6c/ePbLxwzq1kFrtRcKs0tbHBNKQMxs9pXQbkvf5+fpLWA84H1gaXq90fEuu0Yl5nVmEpr9hYyZ+864K8kse8B3A7c2o4xmVmNKsYDjIqlkOS3dESMBoiI1yLiFySrvJiZtYgK2EqlkKkun6dPTHpN0o+Ad4C8DwQ2M8slVVazt5Dk9xOgC/B/JH1/3YBj2jMoM6tNVTHaWy8inklffsRXC5qambVYBeW+Zic530O6hl9jImL/domogqzz7ZW5+tZflTuMzFpui5PKHUKmff7Kf4tanqSqafZeXrIozCwTqqLZGxGPlDIQM6t9lbQeXqHr+ZmZtUmlTXJ28jOzkqmg3Fd48pPUKSI+b89gzKx2Vd1DyyVtKelF4NX0/SaSLmv3yMys5nSoy7/lI+laSbMkvZSzb7ikdyRNSrc985VTSP/jpcBewFyAiJiMb28zsxaqf4BRvq0A1wEDG9l/cUT0TbcH8hVSSPKri4jpDfYtKuBzZmZfU1fAlk9EPA68X4xY8nlL0pZASOog6cfAtLZe2MyypX6Sc74N6ClpQs52XIGXOEnSC2mzeLl8JxeS/I4HTgVWB94Dtk73mZm1SDLo0fwGzImIzXO2qwso+kpgLaAvMAP4Y74PFHJv7yzg0AIubmbWrPaa6hIR79W/ljQCuC/fZwpZyXkEjdzjGxGFVkXNzNp1krOklSJiRvp2P756BlGTCpnn94+c10ulBb/V8vDMLNNUnJqfpJHAAJK+wbeBc4ABkvqSVNTeBIbmK6eQZu9tDS58IzCm5SGbWdapCGs1R8Rhjexu8QPXWnN727eANVrxOTPLMAEdK2hlg0L6/P7HV31+dSTza4a1Z1BmVpsq6fa2ZpNf+uyOTUie2wGwOCKaXODUzKwpyR0e5Y7iK80mv4gISfdERL9SBWRmNarCHmBUSAv8WUmbtXskZlbT6mt++bZSae4ZHh0jYiGwHTBE0mvAJyTfISLCCdHMWqSCuvyabfY+C2wG7FuiWMyshgnRoYKyX3PJTwAR8VqJYjGzWlbiZm0+zSW/FSSd2tTBiLioHeIxsxpW4Hp9JdFc8usAdIEiTMk2s8yrpgcYzYgIP7HbzIqmgip++fv8zMyKQVTPc3t3LlkUZlb7VCV9fhHR5jXyzczq1T/AqFL4oeVmVjKVk/qc/MysZERdlYz2mpkVTTUNeJiZFVXVrOdnZlY01TLaa2ZWTG72mllmudlrZplUQYO9Tn5mVhpJs7dysp+Tn5mVTAW1ep38zKxU
5NFeM8seN3vNLJvkZq+1gwvOPJmnxz5M9x49ue7vTwIw74P/ce6pxzLznbdYcZXVGH7xtXTt1r3MkdaeVXt35y+/PpLePZZlcQTX3vUkV4wcy0brrsJlZx3KMp07Mf3duRx91vV89Mln5Q63rCqp2VtJcw6tDQbudxi/H3H71/bdMuJPbLb1Dtw8+jk223oHbhlxSZmiq20LFy1m2EV3s+kB57HjkRcy9JAdWO/bK3Ll2Yfzi0vvZYuDf8Oof07mJ4OzvURmsZ7bK+laSbMkvZSzb3lJYyS9mv5cLl85Tn41YpMt+tO129f/vZ985AEG7nsoAAP3PZRx/3igHKHVvJlz5jFp6tsAfPzp50x9YyYrr9CdddboxbiJ/wHg0fFT2XfnvuUMsyKogP8V4DpgYIN9w4BHImId4JH0fbOc/GrY+3Nn06PXigD06LUi/3t/Tpkjqn2rr7Q8ffusynMvvcm/X5vBXgM2AmD/XTdj1d55KyM1r07Ku+UTEY8DDRdb3ge4Pn19PQU8b7yqkp+k70vKm9ELLOvjYpRjVm+Zzksy8sIf8tML7+KjTz5j6PCbGXrwDjx58xl0WboTXyxYVO4Qy6oFzd6ekibkbMcVUHzviJgBkP7sle8DFTfgIaljRCxs7FhEjAJGlTikqrV8jxWYO2smPXqtyNxZM1lu+Z7lDqlmdexYx8gLh3DbgxO499HJAEx78z32PuEKANZevRd7bL9BOUOsAAU3a+dExObtHU271fwkLSPpfkmTJb0k6RBJb0rqmR7fXNLY9PVwSVdLehi4QdIzkjbIKWuspH6SjpJ0uaRuaVl16fGlJb0laQlJa0l6SNJESU9IWi8951uSnpb0nKRft9f3riT9v7sHD/3tVgAe+tutbLvznmWOqHZddc4gXnljJpfe9OiX+1ZYrguQ3Mw/bMjujLhzXLnCqwwF1PracO/ve5JWAkh/zsr3gfZs9g4E3o2ITSJiQ+ChPOf3A/aJiMOBW4GD4csvsnJETKw/MSI+BCYDO6a79gZGR8QC4Grg5IjoB5wO/Dk950/AlRGxBTCzqSAkHVdf3f7wf3Nb9o3L6FenDuHEwwby1hv/4cAdN+T+O2/i8CGnMPGpsQzafQsmPjWWw4ecUu4wa1L/vt9m0F5bseMW6zL+1mGMv3UYu2+3PgcP3JwX/nY2k+/5JTNmf8gN944vd6hlVf8Ao7b2+TVhFDA4fT0YuDffB9qz2fsicKGkC4D7IuKJPMvZjIqI+enr24ExwDkkSfCORs6/DTgE+CdwKPBnSV2A/sAdOdfqlP7cFjggfX0jcEFjQUTE1SQJlD4b9o0837FinH3RiEb3X3Td30ocSfY8Nel1Om960jf2j+bfXDFybOkDqmDFmOUnaSQwgKRv8G2SPPE74HZJxwL/BQ7KV067Jb+ImCapH7An8Nu0SbuQr2qbSzX4yCc5n31H0lxJG5MkuKGNXGJUWu7yJLXGR4FlgA8ioqk5BVWTzMxqUTHW84uIw5o41KKJlO3Z57cy8GlE3ARcCGwGvEmSqOCrWlhTbgXOALpFxIsND0bEx8CzJM3Z+yJiUUTMA96QdFAagyRtkn7kSZIaIsCgVn8xM2s1Kf9WKu3Z57cR8KykScBZwHnAucCfJD0B5Bv3v5MkWd3ezDm3AUekP+sNAo6VNBmYQjL/B+AU4ERJzwHdWvhdzKwIVMBWKu3Z7B0NjG7k0LqNnDu8kX3v0SC+iLiOZHZ3/fs7afD7iog3+Obs7/r92+Ts+l0z4ZtZkQkvY29mWeRVXcwsqyoo9zn5mVmpyM1eM8umCsp9Tn5mVhqlHs3Nx8nPzErGzV4zy6QKyn1OfmZWOhWU+5z8zKxE5GavmWVQcodHuaP4ipOfmZVMBeU+Jz8zKx03e80skyoo9zn5mVnpVFDuc/Izs9LwklZmlk1e0srMsqqCcp+Tn5mVipe0MrOMqqDc5+RnZqXhJa3MLLPc7DWzTKqg3OfkZ2alU0G5z8nPzErE
S1qZWRYVa0krSW8CHwGLgIURsXlrynHyM7OSKWK9b6eImNOWApz8zKxk6iqo2VtX7gDMLENUwJZfAA9LmijpuNaG4pqfmZVMgfW+npIm5Ly/OiKuznm/bUS8K6kXMEbS1Ih4vKWxOPmZWUlIBTd75zQ3iBER76Y/Z0m6B9gSaHHyc7PXzEqnjc1eSctI6lr/GtgNeKk1objmZ2YlU4Thjt7APel8wY7ALRHxUGsKcvIzsxJRm0d7I+J1YJNiROPkZ2YlUWnP7XWfn5llkmt+ZlYylTTJ2cnPzErDDzAysyzySs5mllle0srMMqmCcp+Tn5mVTgXlPic/MyudSmr2KiLKHUPFkjQbmF7uONqgJ9CmBR+tTar9979GRKxQrMIkPUTyO8lnTkQMLNZ1m4zHya92SZrQ2iW+re38+69svsPDzDLJyc/MMsnJr7Zdnf8Ua0f+/Vcw9/mZWSa55mdmmeTkZ2aZ5OSXIZL8712BlM78VSXNAM4A/zFkgKRNACJisRNgRVoHICLCCbB0/IdQ4yQtAfxc0v3gBFhJlOgEPCLpCnACLCX/EdQwSXURsQA4Avhc0vXgBFhB6iLic2A9YG9J54ATYKn4D6CGRcTi9OVBwH+B/pJuqj/mBFheEbEofbkFMIqkhn5+eswJsJ35P/4aJ2lfYDhwJTA02aXbwAmwEkg6DLgCuBzYHxgk6ffgBNjevKRVjZGk+PrM9cXArRHxiqTXgWnA3ZLujIgDc2qHVh51wI0RMRWYKmkn4DlJRMQZ4bsQ2o3/X7+G5Ca+nBrDO8BxkjaLiAUR8TbwGLCUpFXKFWsWNVGL+wA4uP5NRLwB3AIcLGkF1/zaj2t+NSQn8R0PbCrpY+A64CfAHZJOB3oBawODI2JuuWLNopx/nxOBVYGuwBnAeEnPASeS9P91ADaPiGpeC7Di+d7eGiPpRyQDHKcCFwEvR8RJkg4HdgWWAc6LiBfKGGZmSTqBpG/veODvwF0RcVY60LEs0Ac4LSJeLGOYmeDkV+UkrQMsGRFT0vc/I1lNZBCwF7A3SffG4ohYkE5/cT9fidR3ReT8/C1wAXAssBNwIPA5fDnA0Smd/mLtzH1+VUzSt0jm8L0mafl0dw9gIrBzRAxM5/kdBRwrqaMTX2nlDFjskP7sBdwP9AP2j4jPgJOAY9L+vS9KH2U2OflVKUmrAkOA94BNgTMlbUDS1J0GvJaedwxwCvDPiFhYpnAzLb3LZoSk7wO/J+lzHRsRX0gaTNIEfixS5Yw1S9zsrVJpLWEwsC7wIcldAm8CDwKfAJcA80hqGkMj4t/lidQAJB0IbBAR50oaAFwFjCe5r/e4+m4LKx2P9lahnP6jxUBfktHB0cAGwO7AnRGxq6QlgU4R8VEZw80USX2AmRHxoaQDgDERMQ+YDJwqaXREjJW0PUlf3xIedS8PN3urUJr4BgEnA8NImrndgFeAlYEhkjaNiC+c+Eon7XfdA+ggqSNJP99dkk4FFpE0eX8mqXtEzI6IeU585ePkV736ALenU1ZOA/4HbAu8DSwBvFXG2DIpIt4nuY1wJeC3wOnAmcAC4BFgK2D19LiVmZNf9Xoe2FbSBmkN7xKS/r0FwHBPkC2b5UiSWxeSxDctIi4DDgAWAkvhEd2K4AGPKiWpO/BTIIBHgc7Aj0g6z2eUM7asktQPOA/4PrA+yYDU58DFETErXbtP6fQWKzMnvyomaWWSuwX2J6lV+M6AEmpkEQkk3UFyV83Z6aDG94ClgXPdv1dZnPxqgKRlSP4tPy53LFmUzrmcFxHz0kcGHAMMi4j5knYDtgMui4jZZQ3UvsbJz6yFGqyecwTwf8AYYDpwDfAAyXSjEek5nSNifrnitcY5+Zm1UjrdaCfgBpLBw9+RJL7OJFNe9omI6eWL0Jrj0V6zVpDUn2QE95KIeDwixpIkwo9IVmdZk+ROG6tQTn5mrbMRsAZwUHonDRExPyIujogTgT6eblTZfHubWQtI2hPYMiKGp4vF
bgPsL+mOiFgkqUP6YKJZ5Y3U8nHyM2tGI9NZZgFbSzojIn6fzt3bhuSxADfWP5HNq7NUPjd7zZqRM6pbv17iv4Cfk9xdc0ZEXEtyb/X6JKtkW5XwaK9ZIxpMZ9kJ+Cuwb0RMShct6EeybNjdEfEHSd0i4sMyhmwt5JqfWQMNEt8JwCrArcCNkjaOiIUR8QwwFRggaXknvurjPj+zBnIS31CSuzX2iYibJM0DrkmXqNoAWBI4Kl3NxaqMk59ZIyTVT1Q+C1iQJsIlSe7TPZDk0QEn+pa16uU+P7MmSDqOZKWct0gGNaaTPG/3N8AC37JW3VzzM2vaDSSju69FxPvp7WwHAAud+Kqfa35meUiqA44GfgwcFhEvlTkkKwLX/MzyWwpYDBwcES+XOxgrDtf8zArQ2MKlVt2c/MwskzzJ2cwyycnPzDLJyc/MMsnJz8wyycnPviRpkaRJkl6SdIekpdtQ1gBJ96Wvvy9pWDPndk8XEGjpNYZLOr3Q/Q3OuU7SgS241pqSPL+vhjj5Wa75EdE3IjYEviC5tetLSrT4v5mIGBURv2vmlO5Ai5OfWVs4+VlTngDWTms8L0v6M/A8sJqk3SQ9Len5tIbYBUDSQElTJY0jeZA66f6jJF2evu4t6R5Jk9OtP8lTz9ZKa51/SM/7qaTnJL0g6dycss6S9IqkfwB98n0JSUPSciZLuqtBbXYXSU9ImiZpr/T8DpL+kHPtoW39RVplcvKzb0gX69wDeDHd1Qe4ISI2JXki2S+AXSJiM2ACcKqkpYARwN7A9sCKTRR/KfBYRGwCbAZMAYaR3D/bNyJ+mj7oex1gS6Av0E/SDpL6AYeSrKiyP7BFAV/n7ojYIr3ey8CxOcfWBHYEvgdclX6HY4EPI2KLtPwhkr5VwHWsyvj2NsvVWdKk9PUTJA/gXhmYHhHj0/1bkyzZ/qQkSJZ5ehpYD3gjIl4FkHQTcFwj1/gucCRA+ryLDyUt1+Cc3dLtX+n7LiTJsCtwT0R8ml5jVAHfaUNJ55E0rbsAo3OO3R4Ri4FXJb2efofdgI1z+gO7pdeeVsC1rIo4+Vmu+RHRN3dHmuBynz8rYExEHNbgvL5AsW4XEvDbiPh/Da7x41Zc4zqS5ecnSzoKGJBzrGGJ1si5AAABFElEQVRZkV775IjITZJIWrOF17UK52avtdR4kof3rA0gaWlJ65Is6f4tSWul5x3WxOcfAY5PP9tB0rIkD/rumnPOaOCYnL7EVST1Ah4H9pPUWVJXkiZ2Pl2BGZKWAAY1OHaQpLo05m8Dr6TXPj49H0nrSvKDiWqQa37WIhExO61BjUwf2wjwi4iYli7+eb+kOcA4YMNGijgFuFrSscAi4PiIeFrSk+lUkgfTfr/vAE+nNc+PgSMi4nlJtwGTSBYWfaKAkH8JPJOe/yJfT7KvAI8BvYEfRcRnkv5C0hf4vJKLzwb2Ley3Y9XECxuYWSa52WtmmeTkZ2aZ5ORnZpnk5GdmmeTkZ2aZ5ORnZpnk5GdmmfT/ASD5obfWzV1hAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Confusion matrix\n",
    "cm = confusion_matrix(y_test, pred_test)\n",
    "plot_confusion_matrix(cm=cm, classes=[\"died\", \"survived\"])\n",
    "print (classification_report(y_test, pred_test))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "iMk7tN1h98x9"
   },
   "source": [
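    "When there are more than two labels (i.e., beyond binary classification), we can choose to compute the evaluation metrics at the micro or macro level (per class label), weighted by class support, and so on. For more details, see the [official docs](http://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_recall_fscore_support.html)."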
   ]
  },
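  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of the averaging options, here is a small sketch on made-up 3-class labels (hypothetical, not from the Titanic data). It calls scikit-learn's `precision_recall_fscore_support` with `average` set to `'micro'`, `'macro'` and `'weighted'`.\n",
    "\n",
    "```python\n",
    "from sklearn.metrics import precision_recall_fscore_support\n",
    "\n",
    "# Made-up 3-class labels, only for illustration\n",
    "y_true = [0, 0, 0, 1, 1, 2]\n",
    "y_pred = [0, 0, 1, 1, 2, 2]\n",
    "\n",
    "results = {}\n",
    "for average in ['micro', 'macro', 'weighted']:\n",
    "    # With an average mode set, the returned support is None\n",
    "    p, r, f, _ = precision_recall_fscore_support(y_true, y_pred, average=average)\n",
    "    results[average] = (p, r, f)\n",
    "    print(f'{average:>8}: precision={p:.3f} recall={r:.3f} f1={f:.3f}')\n",
    "```\n",
    "\n",
    "Micro averaging pools the true/false positives of all classes before computing the metric, macro averaging takes the unweighted mean of the per-class scores, and weighted averaging weights each class score by its support (the number of true samples of that class)."
   ]
  },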
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "9v6zc1_1PWnz"
   },
   "source": [
    "# Inference"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Zl9euDuMPYTN"
   },
   "source": [
    "Now let's see whether you would have survived the Titanic."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 80
    },
    "colab_type": "code",
    "id": "kX9428-EPUzx",
    "outputId": "ef100af7-9861-4900-e9c7-ed6d93c69069"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>cabin</th>\n",
       "      <th>embarked</th>\n",
       "      <th>fare</th>\n",
       "      <th>name</th>\n",
       "      <th>parch</th>\n",
       "      <th>pclass</th>\n",
       "      <th>sex</th>\n",
       "      <th>sibsp</th>\n",
       "      <th>ticket</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>24</td>\n",
       "      <td>E</td>\n",
       "      <td>C</td>\n",
       "      <td>100</td>\n",
       "      <td>Goku Mohandas</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>male</td>\n",
       "      <td>1</td>\n",
       "      <td>E44</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   age cabin embarked  fare           name  parch  pclass   sex  sibsp ticket\n",
       "0   24     E        C   100  Goku Mohandas      2       1  male      1    E44"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Enter your own information\n",
    "X_infer = pd.DataFrame([{\"name\": \"Goku Mohandas\", \"cabin\": \"E\", \"ticket\": \"E44\", \n",
    "                         \"pclass\": 1, \"age\": 24, \"sibsp\": 1, \"parch\": 2, \n",
    "                         \"fare\": 100, \"embarked\": \"C\", \"sex\": \"male\"}])\n",
    "X_infer.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 80
    },
    "colab_type": "code",
    "id": "c6OAAQoaWxAb",
    "outputId": "85eb1c6d-6f53-4bd4-bcc3-90d9ebca74c8"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>fare</th>\n",
       "      <th>parch</th>\n",
       "      <th>sibsp</th>\n",
       "      <th>pclass_1</th>\n",
       "      <th>embarked_C</th>\n",
       "      <th>sex_male</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>24</td>\n",
       "      <td>100</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   age  fare  parch  sibsp  pclass_1  embarked_C  sex_male\n",
       "0   24   100      2      1         1           1         1"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Preprocess\n",
    "X_infer = preprocess(X_infer)\n",
    "X_infer.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 80
    },
    "colab_type": "code",
    "id": "48sj5A0mX5Yw",
    "outputId": "d9571238-70ab-427d-f80c-7b13b00efc95"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>sibsp</th>\n",
       "      <th>parch</th>\n",
       "      <th>fare</th>\n",
       "      <th>pclass_1</th>\n",
       "      <th>pclass_2</th>\n",
       "      <th>pclass_3</th>\n",
       "      <th>embarked_C</th>\n",
       "      <th>embarked_Q</th>\n",
       "      <th>embarked_S</th>\n",
       "      <th>sex_female</th>\n",
       "      <th>sex_male</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>24</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>100</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   age  sibsp  parch  fare  pclass_1  pclass_2  pclass_3  embarked_C  \\\n",
       "0   24      1      2   100         1         0         0           1   \n",
       "\n",
       "   embarked_Q  embarked_S  sex_female  sex_male  \n",
       "0           0           0           0         1  "
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Add the missing feature columns (filled with 0)\n",
    "missing_features = set(X_train.columns) - set(X_infer.columns)\n",
    "for feature in missing_features:\n",
    "    X_infer[feature] = 0\n",
    "\n",
    "# Reorder the columns to match the training data\n",
    "X_infer = X_infer[X_train.columns]\n",
    "X_infer.head()"
   ]
  },
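  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two steps above (adding the missing one-hot columns and reordering) can also be sketched with pandas' `DataFrame.reindex`, which adds absent columns with a fill value and orders them in a single call. The column names below are hypothetical placeholders, not the notebook's actual `X_train.columns`.\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "# Hypothetical training columns and a one-hot encoded inference row\n",
    "train_columns = ['age', 'pclass_1', 'pclass_2', 'pclass_3', 'sex_female', 'sex_male']\n",
    "X_row = pd.DataFrame([{'age': 24, 'pclass_1': 1, 'sex_male': 1}])\n",
    "\n",
    "# Missing columns are added with 0 and all columns are put in training order\n",
    "X_aligned = X_row.reindex(columns=train_columns, fill_value=0)\n",
    "print(X_aligned)\n",
    "```\n",
    "\n",
    "Any column in the inference row that is not among the training columns is dropped by `reindex`, which is usually what we want since the model never saw it."
   ]
  },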
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "rP_i8w9IXFiM"
   },
   "outputs": [],
   "source": [
    "# Standardize with the same scaler that was applied to the training data\n",
    "standardized_X_infer = X_scaler.transform(X_infer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "id": "7O5PbOAvXTzF",
    "outputId": "f1c3597e-1676-476f-e970-168e5c3fca6c"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Looks like I would've survived with about 57% probability on the Titanic expedition!\n"
     ]
    }
   ],
   "source": [
    "# Predict\n",
    "y_infer = log_reg.predict_proba(standardized_X_infer)\n",
    "classes = {0: \"died\", 1: \"survived\"}\n",
    "_class = np.argmax(y_infer)\n",
    "print (\"Looks like I would've {0} with about {1:.0f}% probability on the Titanic expedition!\".format(\n",
    "    classes[_class], y_infer[0][_class]*100.0))"
   ]
  },
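  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For binary logistic regression, the class-1 probability from `predict_proba` should be the sigmoid $\\frac{1}{1+e^{-z}}$ of the linear output $z$ (here including the intercept). The sketch below checks this on synthetic data; it swaps in scikit-learn's `LogisticRegression` instead of the notebook's `SGDClassifier`, so it is an illustration of the idea rather than a check of the notebook's model.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "\n",
    "# Synthetic binary data (not the Titanic dataset)\n",
    "rng = np.random.RandomState(0)\n",
    "X = rng.randn(100, 3)\n",
    "y = (X[:, 0] + 0.5 * X[:, 1] > 0).astype(int)\n",
    "\n",
    "clf = LogisticRegression().fit(X, y)\n",
    "\n",
    "# Linear output z = XW + b, then the sigmoid 1 / (1 + e^(-z))\n",
    "z = X @ clf.coef_.ravel() + clf.intercept_[0]\n",
    "p_manual = 1.0 / (1.0 + np.exp(-z))\n",
    "\n",
    "# Column 1 of predict_proba is P(y = 1) because clf.classes_ is [0, 1]\n",
    "p_sklearn = clf.predict_proba(X)[:, 1]\n",
    "print(np.allclose(p_manual, p_sklearn))\n",
    "```"
   ]
  },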
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "8PLPFFP67tvL"
   },
   "source": [
    "# Interpretability"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "jv6LKNXO7uch"
   },
   "source": [
    "Which features are the most influential?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 68
    },
    "colab_type": "code",
    "id": "KTSpxbwy7ugl",
    "outputId": "b37bf39c-f35d-4793-a479-6e61179fc5e5"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[-0.02155712  0.39758992  0.78341184 -0.0070509  -2.71953415  2.01530102\n",
      "   3.50708962  0.11008796  0.         -0.11008796  2.94675085 -2.94675085]]\n",
      "[5.10843738]\n"
     ]
    }
   ],
   "source": [
    "# Unstandardized coefficients\n",
    "coef = log_reg.coef_ / X_scaler.scale_\n",
    "intercept = log_reg.intercept_ - np.sum((coef * X_scaler.mean_))\n",
    "print (coef)\n",
    "print (intercept)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "xJgiIupyE0Hd"
   },
   "source": [
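    "A positive coefficient indicates a correlation with the positive class (1 = survived), and a negative coefficient indicates a correlation with the negative class (0 = died)."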
   ]
  },
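  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to make this sign rule concrete: in logistic regression the log-odds $\\log\\frac{\\hat{y}}{1-\\hat{y}}$ is linear in the features, so increasing a feature $x_j$ by one unit multiplies the odds of the positive class by $e^{W_j}$. A positive $W_j$ gives a factor above 1 and a negative $W_j$ a factor below 1. The weight and bias below are made-up numbers for a single feature.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Made-up weight and bias for a single feature\n",
    "w, b = 0.8, -0.3\n",
    "\n",
    "def prob(x):\n",
    "    # Logistic regression: sigmoid of the linear output\n",
    "    return 1.0 / (1.0 + np.exp(-(w * x + b)))\n",
    "\n",
    "def odds(p):\n",
    "    return p / (1.0 - p)\n",
    "\n",
    "# Odds ratio for a one-unit increase in x (from 1.0 to 2.0)\n",
    "ratio = odds(prob(2.0)) / odds(prob(1.0))\n",
    "print(ratio, np.exp(w))\n",
    "```"
   ]
  },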
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 51
    },
    "colab_type": "code",
    "id": "RKRB0er2C5l-",
    "outputId": "39ad0cf3-13b1-4aa8-9a6b-4456b8975a39"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Features correlated with death: ['sex_male', 'pclass_1', 'embarked_S']\n",
      "Features correlated with survival: ['pclass_2', 'sex_female', 'pclass_3']\n"
     ]
    }
   ],
   "source": [
    "indices = np.argsort(coef)\n",
    "features = list(X_train.columns)\n",
    "print (\"Features correlated with death:\", [features[i] for i in indices[0][:3]])\n",
    "print (\"Features correlated with survival:\", [features[i] for i in indices[0][-3:]])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "RhhFw3Kg-4aL"
   },
   "source": [
    "### Proof for unstandardizing the coefficients:\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "ER0HFHXj-4h8"
   },
   "source": [
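    "Note that only $X$ was standardized; $y$ still holds the class labels. Standardization only affects the linear part of the model (the log-odds that is passed through the sigmoid), so we can unstandardize the coefficients on that linear part:\n",
    "\n",
    "$\\log\\frac{\\mathbb{E}[y]}{1-\\mathbb{E}[y]} = W_0 + \\sum_{j=1}^{k}W_jz_j$\n",
    "\n",
    "$z_j = \\frac{x_j - \\bar{x}_j}{\\sigma_j}$\n",
    "\n",
    "$\\log\\frac{\\hat{y}}{1-\\hat{y}} = \\hat{W}_0 + \\sum_{j=1}^{k}\\hat{W}_jz_j$\n",
    "\n",
    "$\\log\\frac{\\hat{y}}{1-\\hat{y}} = \\hat{W}_0 + \\sum_{j=1}^{k} \\hat{W}_j (\\frac{x_j - \\bar{x}_j}{\\sigma_j})$\n",
    "\n",
    "$\\log\\frac{\\hat{y}}{1-\\hat{y}} = (\\hat{W}_0 - \\sum_{j=1}^{k} \\hat{W}_j\\frac{\\bar{x}_j}{\\sigma_j}) + \\sum_{j=1}^{k} (\\frac{\\hat{W}_j}{\\sigma_j})x_j$\n",
    "\n",
    "So the unstandardized weights are $\\hat{W}_j / \\sigma_j$ (`log_reg.coef_ / X_scaler.scale_`) and the unstandardized intercept is $\\hat{W}_0 - \\sum_{j=1}^{k} \\frac{\\hat{W}_j}{\\sigma_j}\\bar{x}_j$ (`log_reg.intercept_ - np.sum(coef * X_scaler.mean_)`)."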
   ]
  },
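  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To sanity-check the derivation numerically, the sketch below fits a model on standardized synthetic data and confirms that the unstandardized weights and intercept give the same log-odds on the raw features. It uses scikit-learn's `StandardScaler` and, as a stand-in for the notebook's `SGDClassifier`, `LogisticRegression`; the data is made up.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "# Synthetic features on very different scales (not the Titanic data)\n",
    "rng = np.random.RandomState(42)\n",
    "X = rng.normal(loc=[10.0, -3.0], scale=[5.0, 0.5], size=(200, 2))\n",
    "y = (X[:, 0] - 4 * X[:, 1] > 22).astype(int)\n",
    "\n",
    "scaler = StandardScaler().fit(X)\n",
    "clf = LogisticRegression().fit(scaler.transform(X), y)\n",
    "\n",
    "# W_j / sigma_j and W_0 - sum_j (W_j / sigma_j) * mean_j, as in the derivation\n",
    "coef = clf.coef_ / scaler.scale_\n",
    "intercept = clf.intercept_ - np.sum(coef * scaler.mean_)\n",
    "\n",
    "# Log-odds from the standardized model vs. from the unstandardized parameters\n",
    "z_std = clf.decision_function(scaler.transform(X))\n",
    "z_raw = X @ coef.ravel() + intercept[0]\n",
    "print(np.allclose(z_std, z_raw))\n",
    "```"
   ]
  },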
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "5yBZLVHwGKSj"
   },
   "source": [
    "# K-fold cross-validation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "fHyLTMAAGJ_x"
   },
   "source": [
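    "Cross-validation is a resampling method for evaluating a model. Instead of splitting the data into a training set and a validation set only once at the start, we create k (usually k = 5 or 10) different training/validation splits.\n",
    "\n",
    "Steps:\n",
    "1.   Randomly shuffle the training dataset *train*.\n",
    "2.   Split the dataset into k different folds.\n",
    "3.   In each of the k iterations, pick one fold as the validation set and use all the remaining folds as the training set.\n",
    "4.   Repeat this so that every fold is used as the validation set exactly once (and as part of the training set in the other k-1 iterations).\n",
    "5.   Train the model starting from randomly initialized weights.\n",
    "6.   Reinitialize the model in each of the k iterations, keeping the same random initialization, and then evaluate it on the validation fold.\n",
    "7.   Report the mean and standard deviation of the k validation scores.\n",
    "\n"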
   ]
  },
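  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The steps above can be sketched by hand with scikit-learn's `KFold` (with `shuffle=True`) and `sklearn.base.clone`, which returns an unfitted copy of the model with the same hyperparameters for every fold. This is an illustrative sketch on synthetic data with `LogisticRegression`, not the notebook's exact setup; `k_fold_scores` is a hypothetical helper name.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from sklearn.base import clone\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.model_selection import KFold\n",
    "\n",
    "# Synthetic binary data (not the Titanic dataset)\n",
    "rng = np.random.RandomState(0)\n",
    "X = rng.randn(100, 4)\n",
    "y = (X[:, 0] - X[:, 2] > 0).astype(int)\n",
    "\n",
    "def k_fold_scores(model, X, y, k=5, seed=0):\n",
    "    # Steps 1-2: shuffle the data and split it into k folds\n",
    "    kf = KFold(n_splits=k, shuffle=True, random_state=seed)\n",
    "    scores = []\n",
    "    # Steps 3-4: every fold serves as the validation set exactly once\n",
    "    for train_idx, val_idx in kf.split(X):\n",
    "        # Steps 5-6: a fresh, unfitted copy of the model for each fold\n",
    "        fold_model = clone(model)\n",
    "        fold_model.fit(X[train_idx], y[train_idx])\n",
    "        scores.append(fold_model.score(X[val_idx], y[val_idx]))\n",
    "    return np.array(scores)\n",
    "\n",
    "# Step 7: summarize the k validation accuracies\n",
    "scores = k_fold_scores(LogisticRegression(), X, y, k=5)\n",
    "print('Scores:', scores)\n",
    "print('Mean:', scores.mean(), 'Std:', scores.std())\n",
    "```\n",
    "\n",
    "`cross_val_score` runs the same loop internally; to get the shuffling of step 1 with it, pass a `KFold` or `StratifiedKFold` object with `shuffle=True` as its `cv` argument."
   ]
  },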
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "6XB6X1b0KcvJ"
   },
   "outputs": [],
   "source": [
    "from sklearn.model_selection import cross_val_score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "UIqKmAEtVWMg"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Scores: [0.66666667 0.7        0.7        0.4        0.7        0.7\n",
      " 0.85       0.7        0.68421053 0.78947368]\n",
      "Mean: 0.6890350877192982\n",
      "Standard Deviation: 0.10984701440790533\n"
     ]
    }
   ],
   "source": [
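    "# K-fold cross-validation\n",
    "# (newer scikit-learn releases name this loss 'log_loss' and use penalty=None instead of 'none')\n",
    "# With an integer cv and a classifier, cross_val_score uses stratified folds without shuffling\n",
    "log_reg = SGDClassifier(loss=\"log\", penalty=\"none\", max_iter=args.num_epochs)\n",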
    "scores = cross_val_score(log_reg, standardized_X_train, y_train, cv=10, scoring=\"accuracy\")\n",
    "print(\"Scores:\", scores)\n",
    "print(\"Mean:\", scores.mean())\n",
    "print(\"Standard Deviation:\", scores.std())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "L0aQUomQoni1"
   },
   "source": [
    "# TODO"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "jCpKSu53EA9-"
   },
   "source": [
    "- interaction terms\n",
    "- interpreting odds ratio\n",
    "- simple example with coordinate descent method (sklearn.linear_model.LogisticRegression)"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "05_Logistic_Regression",
   "provenance": [],
   "toc_visible": true,
   "version": "0.3.2"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
